"""
部分旋转 [:,:,20]
对 loss的研究，extraction_rate, c_init,c_fin
"""
import argparse
from collections import defaultdict
from torch import optim
from torch.utils.data import DataLoader
import wandb
from tqdm import trange

from disvae.models.anneal import *
from disvae.models.losses import _kl_normal_loss, _reconstruction_loss
from disvae.models.vae import FVAE
from disvae.models.encoders import get_encoder
from disvae.models.decoders import get_decoder
from utils.visualize import *
from exps.patterns import data_generator
import numpy as np
from utils.helpers import set_seed

# Default hyper-parameters; each one becomes an optional --<name> CLI flag below.
hyperparameter_defaults = {
    'batch_size': 64,
    'learning_rate': 0.001,
    'epochs': 101,
    'beta': 80,
    'dimension': 6,
    'img_id': 8,
    'random_seed': 210,
}

# Expose every default as a command-line override whose type matches the default.
parser = argparse.ArgumentParser()
for name, default in hyperparameter_defaults.items():
    parser.add_argument(f'--{name}', type=type(default), default=default)
args = parser.parse_args()

# Register the run with wandb; the module docstring becomes the run notes.
wandb.init(project="exps", config=args, group='extraction', notes=__doc__)

config = wandb.config
set_seed(config.random_seed)


def get_preds(model, dl):
    """Collect encoder posterior means and labels for every batch in ``dl``.

    Args:
        model: VAE-style module exposing ``model.encoder(img) -> (mu, logvar)``.
        dl: iterable yielding ``(img, label)`` batches; images are reshaped
            to ``(-1, 1, 64, 64)`` and moved to CUDA.

    Returns:
        Tuple ``(preds, targets)`` of concatenated encoder means and labels.
    """
    # Remember the caller's mode so it can be restored afterwards. The old
    # code called model.train() unconditionally on exit, silently re-enabling
    # train-mode behavior even when the caller had put the model in eval mode
    # (as happens when this is called after `vae.eval()`).
    was_training = model.training
    model.eval()
    preds = []
    targets = []
    # One no_grad context around the whole loop instead of one per batch.
    with torch.no_grad():
        for img, label in dl:
            img = img.view(-1, 1, 64, 64).cuda()
            mu, _ = model.encoder(img)
            preds.append(mu)
            targets.append(label)
    preds, targets = torch.cat(preds), torch.cat(targets)
    model.train(was_training)
    return preds, targets


def _kl_normal_loss(mean, logvar, storer=None):
    latent_dim = mean.size(1)
    # batch mean of kl for each latent dimension
    latent_kl = 0.5 * (-1 - logvar + mean.pow(2) + logvar.exp()).mean(dim=0)

    total_kl = (latent_kl).sum()
    if storer is not None:
        storer['KL_loss'].append(total_kl.item())
        for i in range(latent_dim):
            storer['kl_' + str(i)].append(latent_kl[i].item())

    return latent_kl


def get_p(step, len):
    """Piecewise annealing weight for ``step`` of a ``len``-step schedule.

    NOTE(review): the second parameter shadows the builtin ``len``; kept
    as-is to preserve the call signature for keyword callers.

    Returns 0 outside ``[0, len]``; 1 during the first half of the
    schedule, then decays linearly from 1 down to 0 over the second half.
    """
    if not 0 <= step <= len:
        return 0
    t = step / len
    # return 1-t
    return 1 if t < 0.5 else 2 - t * 2

def train(dl, model, epochs, beta, phase):
    """Train ``model`` on ``dl`` for ``epochs`` epochs, logging to wandb.

    ``phase`` selects which entry of ``model.encoders`` is active: its gate
    attribute ``p`` is set to 1 and all others to 0. Returns the last
    epoch's ``storer`` dict of averaged metrics.
    """
    opt = optim.Adam(model.parameters(), config.learning_rate, weight_decay=0)
    C = 3  # NOTE(review): unused local — left for reference
    # KL-weight schedule; start and end are both `beta` here, so `c` is
    # presumably constant at `beta` — TODO confirm Monotonic's semantics.
    anneal = Monotonic(len(dl) * epochs, beta, beta, False)
    # p = torch.zeros((1,config.dimension)).cuda()
    # p[0,phase]=1
    g = 2  # NOTE(review): unused local
    # model.p[0]=torch.linspace(1,0,6).cuda()
    # Gate the encoders so only the `phase`-th one participates
    # (assumes e[1] carries the gate attribute `p` — TODO confirm FVAE layout).
    for e in model.encoders:
        e[1].p = 0.0
    model.encoders[phase][1].p = 1
    model.phase = phase

    for e in trange(epochs):
        storer = defaultdict(list)
        for img, label in dl:
            img = img.view(-1, 1, 64, 64).cuda()

            recon_batch, latent_dist, latent_sample = model(img)

            rec_loss = _reconstruction_loss(img, recon_batch,
                                            storer=storer,
                                            distribution='bernoulli')

            # Local _kl_normal_loss returns the per-dimension KL vector.
            kl_loss = _kl_normal_loss(*latent_dist, storer)
            KL = kl_loss.sum()
            c = anneal.next()
            loss = rec_loss + KL * c

            opt.zero_grad()
            # encoder_opt.zero_grad()
            loss.backward()
            opt.step()
            # encoder_opt.step()
            storer['loss'].append(loss.item())
            storer['c'].append(c)

        # Collapse each metric's per-batch list into its epoch mean.
        for k, v in storer.items():
            if isinstance(v, list):
                storer[k] = np.mean(v)

        # Rank latent dimensions by KL (descending): high-KL dims are the
        # ones the model actually uses.
        kl_loss = torch.Tensor([v for k, v in storer.items() if 'kl_' in k])
        _, index = kl_loss.sort(descending=True)

        if (e + 1) % 20 == 0:
            storer['recon'] = wandb.Image(recon_batch[0, 0])
            storer['img'] = wandb.Image(img[0, 0])

            # Per-latent histograms of the posterior parameters.
            for i in range(latent_sample.size(1)):
                mu, logvar = latent_dist
                storer[f'mu_{i}'] = wandb.Histogram(mu[:, i].data.cpu())
                # NOTE(review): 0.5 * logvar.exp() is half the variance, not
                # the std-dev (that would be (0.5 * logvar).exp()) — confirm
                # whether 'sigma' is the intended quantity.
                storer[f'sigma_{i}'] = wandb.Histogram((0.5 * logvar[:, i].exp()).data.cpu())

        if (e + 1) % 20 == 0 and len(index) >= 3:
            # Scatter the 3 highest-KL latents against the first 3 label factors.
            preds, targets = get_preds(model, dl)
            points = preds[:1000, index[:3]].cpu()
            colors = targets[:1000]
            fig, axes = plt.subplots(3, 3)
            for i in range(3):  # factor
                for j in range(3):  # variable
                    axes[i, j].scatter(points[:, j].numpy(), colors[:, i].numpy(), s=0.2)
                    axes[i, j].set_title(f'{i},{index[j]}')
            storer['correlated'] = wandb.Image(fig)

            # NOTE(review): `imgs` is the module-level dataset, not a parameter.
            fig = plot_projection(imgs[10, ::8, ::4].cuda(), model, dim=index[:2])
            storer['projection_ay'] = wandb.Image(fig)

        plt.close()
        storer['epoch'] = e
        wandb.log(storer, sync=False)
    # NOTE(review): gates were set via index [1] above but reset via [0] here —
    # looks inconsistent; confirm whether this should be encoders[phase][1].
    model.encoders[phase][0].p = 0
    return storer


# generate data
img_id = config.img_id
img = torch.load(f'patterns/{img_id}.pat')

imgs, labels = data_generator.gen_rotation(img, img.shape)

epochs = config.epochs
dim = config.dimension

dataset = imgs[[0, 10], :, :, ].reshape(-1, 1, 64, 64)
nlabels = labels[[0, 10], :, ].reshape(-1, 3)
dl = DataLoader(list(zip(dataset, nlabels)), pin_memory=True,
                num_workers=0, batch_size=config.batch_size, shuffle=True)

vae = FVAE((1, 64, 64), get_encoder('Burgess'), get_decoder('Burgess'), dim, 5)
# for e in vae.encoders:
#     e.cuda();
#     e.train()
#
# encoders = nn.Sequential(*vae.encoders)
# vae.active = encoders

# wandb.watch(vae.encoders[0], 'all', log_freq=10)
vae.cuda();
vae.train()

# Train phase 0 only; `storer` holds the last epoch's averaged metrics.
storer = train(dl, vae, epochs, config.beta, 0)
# torch.save(encoders.state_dict(),'encoders.pt')
# torch.save(vae.state_dict(),'tmp.pt')
# vae.load_state_dict(torch.load('tmp.pt'))
# encoders.load_state_dict(torch.load('encoders.pt'))
# storer = train(dl, vae, 21 , 20, 1)

vae.eval()
#
# Rank latent dimensions by their final KL so plots use the informative ones.
kl_loss = torch.Tensor([v for k, v in storer.items() if 'kl_' in k])
_, index = kl_loss.sort(descending=True)
print(index)

# NOTE(review): get_preds ends with model.train(), undoing vae.eval() above.
preds, targets = get_preds(vae, dl)
points = preds[:1000, index[:3]].cpu()
colors = targets[:1000]
# Scatter the 3 highest-KL latents against label factors 1 and 2.
fig, axes = plt.subplots(2, 3)
for i in range(2):  # factor
    for j in range(3):  # variable
        axes[i, j].scatter(points[:, j].numpy(), colors[:, i + 1].numpy(), s=0.2)
        axes[i, j].set_title(f'{i + 1},{index[j]}')
storer['correlated'] = wandb.Image(fig)

# fig = plot_projection(imgs[::4, ::8, 20].cuda(), vae, dim=index[:2])
# storer['projection_ay'] = wandb.Image(fig)

# 3-D point cloud of the top-3 latents, with scaled labels as extra columns.
point_cloud = torch.cat([points, colors * torch.tensor([6, 6, 6])], dim=1).numpy()
wandb.log({'point_cloud': wandb.Object3D(point_cloud)})

vae.eval()
vae.cpu()
#
# del config

# fig = plot_reconstruct(dataset, (3, 2), vae)
# wandb.log({'reconstruction': wandb.Image(fig)})
#
# Latent-traversal grid along the KL-ranked dimensions.
fig = plt_sample_traversal(None, vae, 7, index, r=3)
storer['traversal'] = wandb.Image(fig)
wandb.log(storer)

# fig = plot_projection(imgs[::4, ::4, 20], vae, dim=index[:2])
# wandb.log({'projection_ay': wandb.Image(fig)})

plt.show()
# wandb.join()
# wandb.join()
